"""RabbitMQ Message Queue helper."""

import contextlib
import functools
import json
import os
import platform
import queue
import random
import ssl
import threading
from urllib import parse
import uuid

import pika
import prometheus_client as prometheus

from . import gitlab
from . import logger
from . import metrics
from . import misc
from . import timer

LOGGER = logger.get_logger(__name__)

# Exchanges of the dead-letter-exchange (DLX) based retry system: messages
# nacked with requeue=False are dead-lettered into the incoming retry
# exchange; the retry queues dead-letter back out via the outgoing exchange
# (the broker-side delay/TTL is configured outside this module).
RETRY_EXCHANGE_IN = os.environ.get('RABBITMQ_RETRY_EXCHANGE_IN',
                                   'cki.exchange.retry.incoming')
RETRY_EXCHANGE_OUT = os.environ.get('RABBITMQ_RETRY_EXCHANGE_OUT',
                                    'cki.exchange.retry.outgoing')

# Prometheus metrics, labeled by routing key where it is known.
METRIC_MESSAGE_SENT = prometheus.Counter(
    'cki_message_sent', 'Number of queue messages sent',
    ['routing_key'])
METRIC_MESSAGE_RECEIVED = prometheus.Counter(
    'cki_message_received', 'Number of queue messages received',
    ['routing_key'])
METRIC_MESSAGE_PROCESSED = prometheus.Counter(
    'cki_message_processed', 'Number of queue messages processed')
METRIC_MESSAGE_ERROR = prometheus.Counter(
    'cki_message_error', 'Number of exceptions during queue message handling')
METRIC_LOAD = metrics.LoadIndex(
    'cki_message_load', 'Normalized indicator of the time spent handling a message')
# Buckets range from 5 ms to 1 h to cover both trivial and long-running handlers.
METRIC_TIME = prometheus.Histogram(
    'cki_message_duration_seconds', 'Time spent handling a message',
    buckets=(0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1, 2, 5, 10, 30, 60,
             2 * 60, 5 * 60, 10 * 60, 30 * 60, 60 * 60, '+Inf'))


class QuietNackException(Exception):
    """Raised from a message handler to requeue the message without an error log."""


class MessageQueue:
    # pylint: disable=too-many-instance-attributes
    """
    RabbitMQ message queue helper.

    Helper to handle queue initialization and message sending.
    MessageQueue.connect() should be used to get a context manager for a
    `pika.channel.Channel`.

    host: RabbitMQ server address. Defaults to the value of the RABBITMQ_HOST
        environment variable or localhost. Multiple hosts can be specified
        in a string separated by whitespace or as a list.
    port: RabbitMQ server port. Defaults to the value of the RABBITMQ_PORT
        environment variable or 5672. With a port of 443 or 5671, uses SSL.
    virtual_host: RabbitMQ virtual host. Defaults to the value of the
        RABBITMQ_VIRTUAL_HOST environment variable or /.
    user: RabbitMQ server user. Defaults to the value of the RABBITMQ_USER
        environment variable or guest.
    password: RabbitMQ server password. Defaults to the value of the
        RABBITMQ_PASSWORD environment variable or guest.
    cafile: ca certificates. Defaults to the value of the RABBITMQ_CAFILE
        environment variable.
    certfile: SSL client private key and corresponding certificate. Defaults
        to the value of the RABBITMQ_CERTFILE environment variable.
    keepalive_s: seconds to keep the channel open after sending the message.
        Defaults to 0.
    dlx_retry: enable Dead Letter Exchange retry. Falls back to leaving messages
        unacked if disabled. Defaults to True.
    """

    def __init__(self, *, host=None, port=None, virtual_host=None, user=None, password=None,
                 connection_params=None, cafile=None, certfile=None,
                 keepalive_s=0, dlx_retry=True):
        # pylint: disable=too-many-arguments,too-many-positional-arguments
        """Init."""
        host = host or os.environ.get('RABBITMQ_HOST', 'localhost')
        port = int(port or misc.get_env_int('RABBITMQ_PORT', 5672))
        virtual_host = virtual_host or os.environ.get('RABBITMQ_VIRTUAL_HOST', '/')
        user = user or os.environ.get('RABBITMQ_USER', 'guest')
        password = password or os.environ.get('RABBITMQ_PASSWORD', 'guest')
        cafile = cafile or os.environ.get('RABBITMQ_CAFILE')
        certfile = certfile or os.environ.get('RABBITMQ_CERTFILE')
        connection_params = dict(connection_params or {})
        if isinstance(host, str):
            host = host.split()
        # SSL is implied by a TLS port or by any certificate configuration,
        # unless the caller already provided explicit ssl_options.
        if (port in (443, 5671) or certfile or cafile) and \
                'ssl_options' not in connection_params:
            connection_params['ssl_options'] = pika.SSLOptions(
                ssl.create_default_context(cafile=cafile))
        if certfile:
            # client-certificate (EXTERNAL) authentication
            self.credentials = pika.credentials.ExternalCredentials()
            connection_params['ssl_options'].context.load_cert_chain(certfile)
        else:
            self.credentials = pika.PlainCredentials(user, password)
        # Allow opting out of strict X.509 verification, see
        # https://github.com/fedora-infra/fedora-messaging/issues/440
        # Only touch the SSL context when one is actually configured; the
        # previous unconditional access raised KeyError on plain connections
        # when RABBITMQ_SSL_VERIFY_X509_STRICT was set to false.
        if ('ssl_options' in connection_params
                and not misc.get_env_bool('RABBITMQ_SSL_VERIFY_X509_STRICT', True)):
            connection_params['ssl_options'].context.verify_flags &= ~ssl.VERIFY_X509_STRICT
        self.connection_params = [pika.ConnectionParameters(
            host=h.rstrip('/'), port=port, virtual_host=virtual_host,
            credentials=self.credentials,
            client_properties={'connection_name': platform.node()},
            **connection_params) for h in host]

        self.keepalive_s = keepalive_s
        self._channel_lock = threading.RLock()
        self._channel = None
        # closes the keepalive channel after keepalive_s seconds of inactivity
        self._disconnect_timer = timer.ScheduledTask(
            keepalive_s, self._disconnect
        )
        self.dlx_retry = dlx_retry

        self.msg_logging_env = MessageLoggingEnv()

    def send_message(self, data, queue_name, exchange='', headers=None, priority=None):
        # pylint: disable=too-many-arguments,too-many-positional-arguments
        """
        Send message to queue.

        Encode `data` as json and send it to routing_key=queue_name,
        exchange=exchange.
        """
        # delivery_mode=2: persistent messages that survive a broker restart
        properties = pika.BasicProperties(delivery_mode=2,
                                          headers=headers,
                                          priority=priority)
        logging_env = {
            'send_message': {
                'body': data,
                'exchange': exchange,
                'headers': headers,
                'routing_key': queue_name,
            }
        }
        with self.connect() as channel:
            body = json.dumps(data)
            with logger.logging_env(logging_env):
                LOGGER.info('Sending message to exchange=%s with routing_key=%s',
                            exchange, queue_name)
            channel.basic_publish(
                body=body, exchange=exchange, routing_key=queue_name,
                properties=properties)
            METRIC_MESSAGE_SENT.labels(queue_name).inc()

    def _queue_declare(self, channel, queue_name, max_priority):
        """
        Declare queue.

        On development environments a disposable queue is created with a random UUID.
        On production/staging environments, a queue name must be specified.

        If dlx_retry is enabled in production/staging mode, a
        cki.queue.retry.{queue_name} is also created.

        Returns the queue name that should be used.
        """
        priority_arguments = {'x-max-priority': max_priority} if max_priority else {}
        if misc.is_production_or_staging():
            # production/staging queue, as durable as possible to not lose msgs
            if self.dlx_retry:
                # If dlx_retry enabled, set x-dead-letter parameters
                channel.queue_declare(
                    queue_name,
                    durable=True,
                    arguments={
                        **priority_arguments,
                        'x-dead-letter-exchange': RETRY_EXCHANGE_IN,
                        'x-dead-letter-routing-key': queue_name
                    }
                )

                # Declare retry queue for DLX based retry system
                channel.queue_declare(
                    f'cki.queue.retry.{queue_name}',
                    durable=True,
                    arguments={
                        'x-dead-letter-exchange': RETRY_EXCHANGE_OUT,
                    }
                )
            else:
                channel.queue_declare(
                    queue_name,
                    durable=True,
                    arguments=priority_arguments,
                )
        else:
            # temporary queue, uuid format for fedora-messaging
            queue_name = str(uuid.uuid4())
            channel.queue_declare(
                queue_name,
                auto_delete=True,
                arguments=priority_arguments,
            )
        return queue_name

    def _queue_bind(self, channel, exchange, queue_name, routing_keys):
        """Bind queue to the exchange."""
        # We're expecting routing_keys to be a list, but str is also ok
        for routing_key in misc.flattened(routing_keys):
            channel.queue_bind(queue_name, exchange,
                               routing_key=routing_key)

        if misc.is_production_or_staging() and self.dlx_retry:
            # Bind retry exchanges: rejected messages go via RETRY_EXCHANGE_IN
            # into the retry queue, and come back via RETRY_EXCHANGE_OUT.
            channel.queue_bind(
                exchange=RETRY_EXCHANGE_IN,
                routing_key=queue_name,
                queue=f'cki.queue.retry.{queue_name}')
            channel.queue_bind(
                exchange=RETRY_EXCHANGE_OUT,
                routing_key=queue_name,
                queue=queue_name)

    @staticmethod
    def _get_routing_key(routing_key, properties):
        """
        Recover original routing_key for DLX requeued message.

        DLX requeuing overrides the original routing_key.
        Restore original routing_key when the message was rejected.
        """
        try:
            return properties.headers['x-death'][-1]['routing-keys'][0]
        except Exception:
            # no x-death header => not a DLX-requeued message
            return routing_key

    # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals
    def _consume_messages_thread(self, thread_queue, thread_quit,
                                 thread_dead_channel,
                                 exchange: str, routing_keys: list,
                                 queue_name: str,
                                 max_priority: int = 0,
                                 prefetch_count: int = 5,
                                 inactivity_timeout: int = None,
                                 return_on_timeout: bool = True):
        """Produce received messages into thread_queue (runs in a daemon thread).

        Each item is a (channel, delivery_tag, routing_key, body, headers,
        is_timeout) tuple; a final None sentinel terminates the consumer loop.
        """
        try:
            with self._connect_no_keepalive() as channel:
                if prefetch_count:
                    channel.basic_qos(prefetch_count=prefetch_count)

                queue_name = self._queue_declare(channel, queue_name, max_priority)
                self._queue_bind(channel, exchange, queue_name, routing_keys)

                for method, properties, body in channel.consume(
                        queue_name, inactivity_timeout=inactivity_timeout):
                    if thread_quit.is_set():
                        LOGGER.info('Terminating thread as requested')
                        return
                    if not method:  # inactivity timeout
                        if return_on_timeout:
                            return
                        # 'null' parses as JSON => callback gets body=None
                        thread_queue.put((channel, None, None, 'null', None, True))
                        continue

                    routing_key = self._get_routing_key(method.routing_key, properties)

                    LOGGER.info('Received payload from %s (%s)',
                                routing_key, method.delivery_tag)
                    thread_queue.put((channel, method.delivery_tag,
                                      routing_key, body, properties.headers, False))
        except Exception:
            # an exception most likely means that the channel is dead
            LOGGER.info('Marking the channel as dead')
            thread_dead_channel.set()
            LOGGER.exception('Exception in consume messages thread')
            if misc.is_production_or_staging():
                LOGGER.critical('Calling os._exit() in production/staging mode')
                misc.sentry_flush()
                # sending SIGINT would raise KeyboardInterrupt in the main
                # thread and run cleanup handlers, but in some cases does not
                # unblock the main thread
                os._exit(1)  # pylint: disable=protected-access
        finally:
            LOGGER.info('Terminating the consumer')
            thread_queue.put(None)  # terminate consumer

    @staticmethod
    def _add_callback_fn(channel, delivery_tag, ack=True):
        """Add a threadsafe ack/nack callback to the channel."""
        if ack:
            method = functools.partial(channel.basic_ack, delivery_tag)
        else:
            # requeue=False dead-letters the message into the retry system
            method = functools.partial(channel.basic_nack, delivery_tag, requeue=False)
        channel.connection.add_callback_threadsafe(method)

    @staticmethod
    @METRIC_TIME.time()
    @METRIC_LOAD.context()
    def measured_callback(function, *args, **kwargs):
        """Execute function and populate metrics."""
        METRIC_MESSAGE_RECEIVED.labels(kwargs.get('routing_key', '')).inc()
        try:
            function(*args, **kwargs)
        except Exception:
            METRIC_MESSAGE_ERROR.inc()
            raise
        METRIC_MESSAGE_PROCESSED.inc()

    def _consume_one(self, item, callback, manual_ack):
        """Parse and dispatch one item produced by the consumer thread."""
        channel, delivery_tag, routing_key, body, headers, timeout = item
        add_callback_fn = functools.partial(self._add_callback_fn, channel, delivery_tag)
        try:
            json_body = json.loads(body)
        except Exception:
            LOGGER.exception('Message parsing failure')
            if delivery_tag:
                add_callback_fn(True)  # retrying will not help
            return

        with logger.logging_env({'message': self.msg_logging_env.render(headers, json_body)}):
            LOGGER.debug('Processing payload from %s (%s)', routing_key, delivery_tag)
            try:
                self.measured_callback(
                    callback,
                    body=json_body,
                    routing_key=routing_key,
                    headers=headers,
                    ack_fn=add_callback_fn if manual_ack and delivery_tag else None,
                    timeout=timeout)
            except QuietNackException:
                # requested requeue without an error log
                if misc.is_production_or_staging() and self.dlx_retry and delivery_tag:
                    add_callback_fn(False)
            except Exception:
                if not (misc.is_production_or_staging() and self.dlx_retry):
                    LOGGER.exception('Message handling failure, '
                                     'will be requeued after restart')
                elif delivery_tag:
                    add_callback_fn(False)
                    LOGGER.exception('Message handling failure, '
                                     'will be requeued after some time')
                elif manual_ack:
                    LOGGER.exception('Failure while processing timeout callback, '
                                     'messages might be left unacked and block further delivery.')
                else:
                    LOGGER.exception('Failure while processing timeout callback.')
            else:
                if not manual_ack and delivery_tag:
                    add_callback_fn(True)

    def consume_messages(self, exchange: str, routing_keys: list,
                         callback: callable, *args,
                         manual_ack=False,
                         queue_name: str = None,
                         **kwargs):
        """Endlessly consume messages.

        The callback has the signature (**kwargs) with the following optional
        kwargs:
            body: body
            routing_key: routing key
            headers: headers
            ack_fn: If manual_ack=True, the callback is expected to call ack_fn
                when it wants to acknowledge a message.
            timeout: If return_on_timeout=False, the callback will be called with
                timeout=True.
        The callback should specify the kwargs it is interested in, but also
        have a **_ argument at the end to gobble up any unknown kwargs.
        """
        if misc.is_production_or_staging() and not queue_name:
            raise Exception('A queue name needs to be specified in production/staging')
        thread_queue = queue.Queue()
        thread_quit = threading.Event()
        thread_dead_channel = threading.Event()
        threading.Thread(target=lambda: self._consume_messages_thread(
            thread_queue, thread_quit, thread_dead_channel,
            exchange, routing_keys, queue_name, *args, **kwargs), daemon=True).start()
        try:
            while not thread_dead_channel.is_set():
                item = thread_queue.get()
                if not item:
                    LOGGER.info('Terminating message callback loop because of sentinel')
                    break
                self._consume_one(item, callback, manual_ack)
            LOGGER.info('Left message callback loop')
        finally:
            LOGGER.info('Signaling producer thread to terminate')
            thread_quit.set()  # terminate producer on signal in queue.get()

    def connect(self):
        """Connect to the server and return a channel."""
        if self.keepalive_s:
            return self._connect_and_keepalive()
        return self._connect_no_keepalive()

    def _get_connection(self):
        """Return a new connection, trying the configured hosts in random order."""
        params = self.connection_params[:]
        random.shuffle(params)
        return pika.BlockingConnection(params)

    @contextlib.contextmanager
    def _connect_no_keepalive(self):
        """Create connection and close it after use."""
        connection = self._get_connection()
        LOGGER.debug('Creating new channel')
        try:
            yield connection.channel()
        finally:
            connection.close()

    @contextlib.contextmanager
    def _connect_and_keepalive(self):
        """Create connection and schedule timer to close it."""
        self._disconnect_timer.cancel()
        with self._channel_lock:
            if not self._channel:
                connection = self._get_connection()
                LOGGER.debug('Creating new channel')
                self._channel = connection.channel()
            try:
                yield self._channel
            except Exception:
                # assume the channel is broken and reconnect next time
                self._disconnect()
                raise
            self._disconnect_timer.start()

    def _disconnect(self):
        """Close the connection."""
        LOGGER.debug('Closing the connection')
        with self._channel_lock:
            if self._channel:
                # there is no use in alerting about already closed channels
                with contextlib.suppress(Exception):
                    self._channel.connection.close()
                self._channel = None


class Message:
    """Webhook message."""

    def __init__(self, payload):
        """Initialize the instance with a preparsed payload."""
        self.payload = payload

    def gl_instance(self):
        """Return a Gitlab API instance."""
        return gitlab.get_instance(self.gitlab_url())

    def gitlab_url(self):
        """Return the GitLab URL."""
        # most hook payloads carry a project dict; push-style payloads only
        # provide the repository homepage
        if 'project' in self.payload:
            raw_url = self.payload['project']['web_url']
        else:
            raw_url = self.payload['repository']['homepage']
        parts = parse.urlsplit(raw_url)
        # keep scheme and host only, drop path/query/fragment
        return parse.urlunparse((parts.scheme, parts.netloc, '', '', '', ''))


class MessageLoggingEnv:
    """Logging environment for message callbacks."""

    def __init__(self):
        """Set up the default per-message-type hooks."""
        self.hooks = [
            ('gitlab', self.hook_gitlab),
            ('datawarehouse', self.hook_datawarehouse),
        ]

    def add_hook(self, message_type, callback):
        """
        Add a new hook to the list.

        Messages matching message_type will be processed with the
        specified function to add data to the logging environment.
        """
        self.hooks.append((message_type, callback))

    def render(self, headers, body):
        """
        Return a dictionary containing logging information.

        Process the message content and headers with the hooks and return
        the logging environment.
        """
        # Some calls, such as timeouts on consume loops, do not
        # contain headers or body.
        headers, body = headers or {}, body or {}
        message_type = headers.get('message-type')

        content = {'headers': headers}
        matching = (fn for hook_type, fn in self.hooks if hook_type == message_type)
        for hook in matching:
            content.update(hook(headers, body))
        return content

    @staticmethod
    def hook_gitlab(_, body):
        """Return logging variables for a Gitlab message."""
        variables = misc.key_value_list_to_dict(
            misc.get_nested_key(body, 'object_attributes/variables')
        )
        kind = body.get('object_kind')

        result = {
            'object_kind': kind,
            'user': {
                'id': misc.get_nested_key(body, 'user/id'),
            },
        }

        # This works for every object kind except build
        if 'project' in body:
            result['project'] = {
                'id': misc.get_nested_key(body, 'project/id'),
                'path_with_namespace': misc.get_nested_key(
                    body, 'project/path_with_namespace'
                ),
            }

        if kind == 'merge_request':
            result['merge_request'] = {
                'iid': misc.get_nested_key(body, 'object_attributes/iid'),
            }
        if kind == 'note':
            result['merge_request'] = {
                'iid': misc.get_nested_key(body, 'merge_request/iid'),
            }
            result['note'] = {
                'noteable_type': misc.get_nested_key(body, 'object_attributes/noteable_type'),
            }
        if kind == 'pipeline':
            result['merge_request'] = {
                'iid': misc.get_nested_key(body, 'merge_request/iid'),
            }
            result['pipeline'] = {
                'id': misc.get_nested_key(body, 'object_attributes/id'),
                'variables': {
                    'mr_id': variables.get('mr_id'),
                    'mr_project_id': variables.get('mr_project_id'),
                },
            }
        if kind == 'build':
            # build payloads use flat keys instead of nested dicts
            result['project'] = {
                'id': body.get('project_id'),
                'name': body.get('project_name'),
            }
            result['build'] = {'id': body.get('build_id')}
            result['pipeline'] = {'id': body.get('pipeline_id')}
        if kind == 'push':
            result['ref'] = body.get('ref')
            result['user'] = {'id': body.get('user_id')}

        return {'gitlab': result}

    @staticmethod
    def hook_datawarehouse(_, body):
        """Return logging variables for a DW message."""
        details = {
            'status': body.get('status'),
            'object_type': body.get('object_type'),
            'id': misc.get_nested_key(body, 'object/id'),
            'iid': misc.get_nested_key(body, 'object/misc/iid'),
            'misc': body.get('misc'),
        }
        return {'datawarehouse': details}
